import scipy as sp
import numpy as np
import matplotlib.pyplot as plt
from scipy.sparse import linalg as splinalg


def run_momentum_batch(A, b,x,n,d,max_iter, batch_size, learning_rate,
                       Delta, loss_history):
  """
    Run mini-batch SGD with heavy-ball momentum on the least-squares loss
    0.5 * ||A x - b||^2 and record the loss curve.

    Each iteration performs the update
        x_new = x - learning_rate * g(x) + Delta * (x - x_prev)
    where g(x) = A[B].T @ (A[B] @ x - b[B]) is the (unnormalised) gradient
    on a batch B of `batch_size` distinct row indices drawn uniformly
    without replacement from 0..n-1 using NumPy's global random state.

    Parameters
    ------------
        INPUTS:
        ---------
        array     A               : data matrix; rows are indexed by the sampled batch
        array     b               : target vector, one entry per row of A
        array     x               : initial weights (not modified in place)
        int       n               : number of rows of A to sample from; an
                                    integral float such as 100.0 is accepted
                                    and converted with int()
        int       d               : column size of A (unused; kept for
                                    interface compatibility)
        int       max_iter        : number of iterations/updates to weights
        int       batch_size      : number of rows in each stochastic batch
                                    (must not exceed n)
        float     learning_rate   : step-size parameter
        float     Delta           : momentum parameter
        list      loss_history    : list that receives the loss values; it is
                                    appended to in place

        OUTPUT:
        -------
        tuple   (x, loss_history):
                  (1) x             : weights after max_iter updates
                  (2) loss_history  : the same list object that was passed in,
                                      with one full-data loss value appended
                                      per iteration. Each value is recorded
                                      BEFORE that iteration's update, so the
                                      loss of the returned x is not included.
  """
  # Converted once so the sampler works even when n arrives as a float;
  # an integer population size draws the same indices as range(n) would.
  num_rows = int(n)

  def loss(w):
    residual = A@w - b
    return .5 * np.dot(residual, residual)

  def grad(w):
    batch = np.random.choice(num_rows, batch_size, replace=False)
    # Slice the batch rows once; fancy indexing copies, so reuse the result.
    A_batch = A[batch,:]
    return A_batch.T @ (A_batch@w - b[batch])

  t = 0
  x_prev = x  # first step has zero momentum because x - x_prev == 0

  while t < max_iter:
    loss_history.append(loss(x))
    momentum = Delta * (x - x_prev)
    # Rebinding (not in-place update) keeps x_prev pointing at the old iterate.
    x_prev = x
    x = x - learning_rate * grad(x) + momentum
    t += 1
  return (x,loss_history)
